BATS forecaster (multiple seasonality)#
BATS stands for Box–Cox transform, ARMA errors, Trend, and Seasonal components.
This notebook builds a practical BATS-style forecaster that supports multiple seasonalities (e.g., weekly + monthly), with a scikit-learn-like API:
- `bats = BATS(use_box_cox=..., box_cox_bounds=..., use_trend=..., use_damped_trend=..., seasonal_periods=..., use_arma_errors=...)`
- `model = bats.fit(y)`
- `forecast = model.forecast(steps)`
Implementation note: the original BATS model is formulated in state-space / exponential-smoothing form. Here we implement a BATS-style forecaster using:
explicit trend + seasonal design matrices, and
ARMA errors estimated via `statsmodels` (`SARIMAX` with `d=0`).
Model sketch (math)#
Let \(y_t\) be the observed series.
Box–Cox transform (optional)#
For \(y_t>0\) and parameter \(\lambda\):
$$g_\lambda(y_t) = \begin{cases} \dfrac{y_t^{\lambda}-1}{\lambda}, & \lambda \ne 0 \\ \log(y_t), & \lambda = 0 \end{cases}$$
We model the transformed series \(x_t = g_\lambda(y_t)\).
Trend + multiple seasonalities#
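The transformed series is decomposed as \(x_t = \beta_0 + \beta_1 f(t) + \sum_k S^{(k)}_t + u_t\), where \(f(t)=t\) for a linear trend.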
For a simple damped trend option we use \(f(t)=\dfrac{1-\phi^t}{1-\phi}\) with damping \(\phi\in(0,1)\).
Each seasonal component \(S^{(k)}_t\) is encoded with seasonal dummies for period \(m_k\).
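ARMA errors (optional)#
The error term \(u_t\) (what remains of the transformed series after the intercept, trend, and seasonal terms) follows an ARMA(\(p,q\)) process: $$u_t = \sum_{i=1}^{p}\varphi_i\,u_{t-i} + \varepsilon_t + \sum_{j=1}^{q}\theta_j\,\varepsilon_{t-j}$$ where \(\varepsilon_t\) is white noise. With ARMA errors disabled, \(u_t=\varepsilon_t\).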
import warnings
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import plotly.io as pio
from scipy import stats
import statsmodels.api as sm
# Notebook-wide configuration: silence UserWarning-category warnings, pick the
# Plotly theme and renderer (renderer overridable via PLOTLY_RENDERER), and
# create the seeded random generator used by the demo cells.
warnings.filterwarnings("ignore", category=UserWarning)
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
rng = np.random.default_rng(7)

# Report the versions of the main dependencies for reproducibility.
import numpy, pandas, scipy, statsmodels, plotly

for _pkg in (numpy, pandas, scipy, statsmodels, plotly):
    print(f"{_pkg.__name__}:", _pkg.__version__)
numpy: 1.26.2
pandas: 2.1.3
scipy: 1.15.0
statsmodels: 0.14.4
plotly: 6.5.2
class BoxCoxTransformer:
    """Optional Box-Cox transform with an automatic positivity shift.

    When ``use_box_cox`` is False the transformer is an identity map (it
    returns copies of its input). Otherwise ``fit`` chooses a shift that
    makes the data strictly positive and estimates the Box-Cox parameter
    by maximum likelihood, constrained to ``box_cox_bounds``.

    Attributes set by ``fit``:
        shift_: constant added to the data before transforming.
        lambda_: fitted Box-Cox parameter, or ``None`` when disabled.
    """

    def __init__(self, use_box_cox: bool, box_cox_bounds: tuple[float, float] = (0.0, 1.0)):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.shift_: float = 0.0
        self.lambda_: float | None = None

    def fit(self, y: np.ndarray) -> "BoxCoxTransformer":
        """Estimate the shift and the Box-Cox lambda from ``y``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if the shifted data is still not strictly positive.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            self.shift_ = 0.0
            self.lambda_ = None
            return self

        # Shift so that the smallest observation maps to 1 when the data
        # contains non-positive values.
        min_y = float(np.min(y))
        self.shift_ = 0.0 if min_y > 0.0 else (1.0 - min_y)
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lo, hi = sorted(self.box_cox_bounds)
        if lo == hi:
            # Degenerate interval: the caller pinned lambda, nothing to optimise.
            self.lambda_ = lo
            return self

        # `brack` is only the starting bracket of SciPy's Brent search, not a
        # constraint, so the optimum may land outside [lo, hi]. Clip it back
        # so the configured bounds are actually honoured.
        raw_lambda = float(stats.boxcox_normmax(y_pos, brack=(lo, hi), method="mle"))
        self.lambda_ = float(np.clip(raw_lambda, lo, hi))
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        """Apply the fitted shift and Box-Cox transform to ``y``.

        Raises:
            RuntimeError: if Box-Cox is enabled but ``fit`` was not called.
            ValueError: if the shifted data is not strictly positive.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            return y.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before transform().")
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")
        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            # lambda ~ 0 is the log-transform limit of Box-Cox.
            return np.log(y_pos)
        return (np.power(y_pos, lmbda) - 1.0) / lmbda

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map transformed values back to the original scale.

        Raises:
            RuntimeError: if Box-Cox is enabled but ``fit`` was not called.
        """
        x = np.asarray(x, dtype=float)
        if not self.use_box_cox:
            return x.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before inverse_transform().")
        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            y_pos = np.exp(x)
        else:
            # Inputs outside the image of the transform (e.g. wide interval
            # bounds) make the base negative, which would yield NaN. Clamp
            # the base to the boundary value 0 instead.
            base = np.maximum(lmbda * x + 1.0, 0.0)
            y_pos = np.power(base, 1.0 / lmbda)
        return y_pos - self.shift_
def _acf(x: np.ndarray, max_lag: int) -> tuple[np.ndarray, np.ndarray]:
x = np.asarray(x, dtype=float)
x = x - x.mean()
denom = float(np.dot(x, x))
lags = np.arange(max_lag + 1)
values = np.zeros(max_lag + 1)
values[0] = 1.0
if denom == 0.0:
return lags, values
for k in range(1, max_lag + 1):
values[k] = float(np.dot(x[k:], x[:-k]) / denom)
return lags, values
def seasonal_dummies(t: np.ndarray, period: int, *, drop_first: bool = True) -> np.ndarray:
    """Build 0/1 indicator columns for the seasonal phase ``t mod period``.

    With ``drop_first=True`` phase 0 is the reference level and gets no
    column, giving ``period - 1`` columns; otherwise one column per phase.
    A period of 1 or less carries no seasonal information and produces a
    matrix with zero columns.
    """
    t = np.asarray(t, dtype=int)
    period = int(period)
    if period <= 1:
        return np.zeros((t.size, 0), dtype=float)

    phase = np.mod(t, period)
    first_kept = 1 if drop_first else 0
    kept_phases = np.arange(first_kept, period)
    # Row i has a 1 in the column whose phase equals phase[i]; rows at the
    # dropped reference phase stay all-zero.
    return (phase[:, np.newaxis] == kept_phases[np.newaxis, :]).astype(float)
def trend_feature(t: np.ndarray, *, use_damped: bool, damped_phi: float) -> np.ndarray:
    """Return the trend regressor evaluated at time index ``t``.

    Without damping this is ``t`` itself (as floats). With damping it is
    ``f(t) = (1 - phi**t) / (1 - phi)``, which satisfies ``f(0) = 0`` and
    approaches ``t`` as ``phi`` tends to 1. ``damped_phi`` is only
    validated when damping is requested.

    Raises:
        ValueError: if damping is requested and ``damped_phi`` is not
            strictly between 0 and 1.
    """
    steps = np.asarray(t, dtype=float)
    if use_damped:
        phi = float(damped_phi)
        if not (0.0 < phi < 1.0):
            raise ValueError("damped_phi must be in (0, 1)")
        decay = np.power(phi, steps)
        return (1.0 - decay) / (1.0 - phi)
    return steps
def bats_design_matrix(
    t: np.ndarray,
    *,
    use_trend: bool,
    use_damped_trend: bool,
    damped_trend_phi: float,
    seasonal_periods: list[int] | None,
) -> np.ndarray:
    """Assemble the regression design matrix for time index ``t``.

    Columns, in order: an intercept, an optional trend column (linear or
    damped, see ``trend_feature``), then for each seasonal period its
    indicator columns with the first phase dropped.
    """
    index = np.asarray(t, dtype=int)
    blocks = [np.ones((index.size, 1), dtype=float)]

    if use_trend:
        trend_values = trend_feature(
            index.astype(float),
            use_damped=use_damped_trend,
            damped_phi=damped_trend_phi,
        )
        blocks.append(trend_values.reshape(-1, 1))

    # `None` and an empty list both mean "no seasonal terms".
    blocks.extend(
        seasonal_dummies(index, period=int(period), drop_first=True)
        for period in (seasonal_periods or [])
    )
    return np.concatenate(blocks, axis=1)
class BATSModel:
def __init__(
self,
*,
results,
transformer: BoxCoxTransformer,
use_trend: bool,
use_damped_trend: bool,
damped_trend_phi: float,
seasonal_periods: list[int] | None,
y_index,
):
self.results = results
self.transformer = transformer
self.use_trend = use_trend
self.use_damped_trend = use_damped_trend
self.damped_trend_phi = float(damped_trend_phi)
self.seasonal_periods = seasonal_periods
self.y_index = y_index
@property
def n_obs(self) -> int:
return int(self.results.nobs)
def fitted_values(self) -> np.ndarray:
fitted_x = np.asarray(self.results.fittedvalues, dtype=float)
return self.transformer.inverse_transform(fitted_x)
def residuals(self) -> np.ndarray:
# Residuals in the transformed space (more natural under Box-Cox)
return np.asarray(self.results.resid, dtype=float)
def forecast(self, steps: int, *, alpha: float = 0.05) -> dict[str, np.ndarray]:
steps = int(steps)
t_future = np.arange(self.n_obs, self.n_obs + steps)
X_future = bats_design_matrix(
t_future,
use_trend=self.use_trend,
use_damped_trend=self.use_damped_trend,
damped_trend_phi=self.damped_trend_phi,
seasonal_periods=self.seasonal_periods,
)
fcst = self.results.get_forecast(steps=steps, exog=X_future)
mean_x = np.asarray(fcst.predicted_mean, dtype=float)
ci = fcst.conf_int(alpha=alpha)
ci_np = np.asarray(ci)
lower_x = ci_np[:, 0]
upper_x = ci_np[:, 1]
mean_y = self.transformer.inverse_transform(mean_x)
lower_y = self.transformer.inverse_transform(lower_x)
upper_y = self.transformer.inverse_transform(upper_x)
return {"mean": mean_y, "lower": lower_y, "upper": upper_y}
class BATS:
def __init__(
self,
*,
use_box_cox: bool = False,
box_cox_bounds: tuple[float, float] = (0.0, 1.0),
use_trend: bool = True,
use_damped_trend: bool = False,
damped_trend_phi: float = 0.98,
seasonal_periods: list[int] | None = None,
use_arma_errors: bool = True,
arma_order: tuple[int, int] | None = (1, 1),
max_arma_order: int = 1,
show_warnings: bool = True,
):
self.use_box_cox = bool(use_box_cox)
self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
self.use_trend = bool(use_trend)
self.use_damped_trend = bool(use_damped_trend)
self.damped_trend_phi = float(damped_trend_phi)
self.seasonal_periods = None if seasonal_periods is None else [int(m) for m in seasonal_periods]
self.use_arma_errors = bool(use_arma_errors)
self.arma_order = None if arma_order is None else (int(arma_order[0]), int(arma_order[1]))
self.max_arma_order = int(max_arma_order)
self.show_warnings = bool(show_warnings)
def _fit_sarimax(self, y_x: np.ndarray, X: np.ndarray, order: tuple[int, int]) -> tuple[object, float]:
p, q = order
res = sm.tsa.SARIMAX(
y_x,
exog=X,
order=(p, 0, q),
trend="n",
enforce_stationarity=True,
enforce_invertibility=True,
).fit(disp=False, method="lbfgs", maxiter=300)
return res, float(res.aic)
def _select_arma_order(self, y_x: np.ndarray, X: np.ndarray) -> tuple[int, int]:
candidates = []
for p in range(self.max_arma_order + 1):
for q in range(self.max_arma_order + 1):
candidates.append((p, q))
best_order = (0, 0)
best_aic = np.inf
for order in candidates:
try:
_, aic = self._fit_sarimax(y_x, X, order)
except Exception:
continue
if aic < best_aic:
best_aic = aic
best_order = order
if best_aic == np.inf:
raise RuntimeError("Failed to fit any ARMA(p,q) candidate.")
return best_order
def fit(self, y) -> BATSModel:
if isinstance(y, pd.Series):
y_index = y.index
y_np = y.to_numpy(dtype=float)
else:
y_index = None
y_np = np.asarray(y, dtype=float)
t = np.arange(y_np.size)
X = bats_design_matrix(
t,
use_trend=self.use_trend,
use_damped_trend=self.use_damped_trend,
damped_trend_phi=self.damped_trend_phi,
seasonal_periods=self.seasonal_periods,
)
transformer = BoxCoxTransformer(self.use_box_cox, box_cox_bounds=self.box_cox_bounds).fit(y_np)
y_x = transformer.transform(y_np)
if not self.use_arma_errors:
chosen_order = (0, 0)
elif self.arma_order is not None:
chosen_order = self.arma_order
else:
chosen_order = self._select_arma_order(y_x, X)
res, aic = self._fit_sarimax(y_x, X, chosen_order)
if self.show_warnings:
print(f"Chosen ARMA(p,q) = {chosen_order}, AIC = {aic:.2f}")
return BATSModel(
results=res,
transformer=transformer,
use_trend=self.use_trend,
use_damped_trend=self.use_damped_trend,
damped_trend_phi=self.damped_trend_phi,
seasonal_periods=self.seasonal_periods,
y_index=y_index,
)
Demo: synthetic series with two seasonalities#
We’ll simulate a daily series with:
weekly seasonality (\(m_1=7\))
~monthly seasonality (\(m_2=30\))
a small trend
ARMA-like correlated noise
def simulate_arma11(n: int, *, phi: float, theta: float, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Simulate ``n`` steps of a zero-mean ARMA(1,1) process.

    The recursion is ``u[t] = phi * u[t-1] + eps[t] + theta * eps[t-1]``
    with Gaussian innovations of standard deviation ``sigma`` drawn from
    ``rng``; both lagged terms are taken as zero at ``t = 0``.
    """
    shocks = rng.normal(0.0, sigma, size=n)

    # Moving-average contribution theta * eps[t-1], zero for the first step.
    lagged_ma = np.zeros(n)
    lagged_ma[1:] = theta * shocks[:-1]

    series = np.zeros(n)
    ar_term = 0.0  # no previous value at t = 0
    for step in range(n):
        series[step] = ar_term + shocks[step] + lagged_ma[step]
        ar_term = phi * series[step]
    return series
# Build the synthetic demo series: 420 daily observations from 2020-01-01.
n = 420
idx = pd.date_range("2020-01-01", periods=n, freq="D")
t = np.arange(n)
# Two sinusoidal seasonal components with periods 7 (weekly) and 30 (~monthly).
weekly = 2.0 * np.sin(2 * np.pi * t / 7) + 0.5 * np.cos(2 * np.pi * t / 7)
monthly = 1.2 * np.sin(2 * np.pi * t / 30) - 0.3 * np.cos(2 * np.pi * t / 30)
# Slow linear drift plus ARMA(1,1) noise drawn from the notebook-level `rng`.
trend = 0.01 * t
noise = simulate_arma11(n, phi=0.6, theta=0.4, sigma=0.8, rng=rng)
# Sum the components around a constant level of 30 and attach the date index.
y = 30.0 + trend + weekly + monthly + noise
y = pd.Series(y, index=idx, name="y")
# Plot the raw series.
fig = go.Figure()
fig.add_trace(go.Scatter(x=y.index, y=y, name="y", line=dict(color="black")))
fig.update_layout(title="Synthetic multi-seasonal series", xaxis_title="date", yaxis_title="value")
fig.show()
# Train/test split + fit
# Hold out the last h observations as the test set.
h = 60
y_train = y.iloc[:-h]
y_test = y.iloc[-h:]
# No Box-Cox, linear trend, seasonal periods 7 and 30, fixed ARMA(1,1) errors.
bats = BATS(
use_box_cox=False,
box_cox_bounds=(0.0, 1.0),
use_trend=True,
use_damped_trend=False,
seasonal_periods=[7, 30],
use_arma_errors=True,
arma_order=(1, 1),
show_warnings=True,
)
model = bats.fit(y_train)
# Forecast exactly the held-out horizon so the output aligns with y_test.
fcst = model.forecast(h)
# Re-attach date indexes: forecast() and fitted_values() return plain arrays.
fitted = pd.Series(model.fitted_values(), index=y_train.index)
pred_mean = pd.Series(fcst["mean"], index=y_test.index)
pred_lower = pd.Series(fcst["lower"], index=y_test.index)
pred_upper = pd.Series(fcst["upper"], index=y_test.index)
# Plot train, in-sample fit, test, and the forecast with its interval band.
fig = go.Figure()
fig.add_trace(go.Scatter(x=y_train.index, y=y_train, name="train", line=dict(color="rgba(0,0,0,0.35)")))
fig.add_trace(go.Scatter(x=y_train.index, y=fitted, name="fitted", line=dict(color="#4E79A7")))
fig.add_trace(go.Scatter(x=y_test.index, y=y_test, name="test", line=dict(color="black")))
# The invisible upper-bound trace must come immediately before the lower-bound
# trace: fill="tonexty" shades the area between a trace and the previous one.
fig.add_trace(go.Scatter(x=y_test.index, y=pred_upper, line=dict(width=0), showlegend=False))
fig.add_trace(
go.Scatter(
x=y_test.index,
y=pred_lower,
fill="tonexty",
fillcolor="rgba(78,121,167,0.18)",
line=dict(width=0),
name="95% interval (approx)",
)
)
fig.add_trace(go.Scatter(x=y_test.index, y=pred_mean, name="forecast mean", line=dict(color="#E15759")))
fig.update_layout(title="BATS forecast on multi-seasonal series", xaxis_title="date", yaxis_title="value")
fig.show()
Chosen ARMA(p,q) = (1, 1), AIC = 841.30
# Residual diagnostics (in transformed space)
# Drop the first `warmup` residuals before computing any statistic.
warmup = 10
resid = model.residuals()
resid_use = resid[warmup:]

resid_mean = resid_use.mean()
resid_std = resid_use.std(ddof=1)
print("residual mean:", float(resid_mean))
print("residual std:", float(resid_std))
print("Jarque-Bera:", stats.jarque_bera(resid_use))

# Autocorrelation up to lag 30 with the usual +/- 1.96 / sqrt(n) band.
lags, acf_vals = _acf(resid_use, max_lag=30)
bound = 1.96 / np.sqrt(resid_use.size)

# QQ data: standardized, sorted residuals against normal quantiles at
# plotting positions (i - 0.5) / n.
nq = resid_use.size
p = (np.arange(1, nq + 1) - 0.5) / nq
theoretical = stats.norm.ppf(p)
sample_q = np.sort((resid_use - resid_mean) / resid_std)

panel_titles = ("Residuals over time", "Residual histogram", "Residual ACF", "QQ plot (std residuals)")
fig = make_subplots(rows=2, cols=2, subplot_titles=panel_titles)

# Top-left: residuals over time with a zero reference line.
fig.add_trace(
    go.Scatter(x=y_train.index[warmup:], y=resid_use, name="residuals", line=dict(color="#4E79A7")),
    row=1,
    col=1,
)
fig.add_hline(y=0, line=dict(color="black", dash="dash"), row=1, col=1)

# Top-right: histogram of residuals.
fig.add_trace(go.Histogram(x=resid_use, nbinsx=30, name="hist", marker_color="#4E79A7"), row=1, col=2)

# Bottom-left: ACF bars with the upper and lower significance bounds.
fig.add_trace(go.Bar(x=lags, y=acf_vals, name="ACF(resid)", marker_color="#4E79A7"), row=2, col=1)
for level in (bound, -bound):
    fig.add_trace(
        go.Scatter(
            x=[0, lags.max()],
            y=[level, level],
            mode="lines",
            line=dict(color="gray", dash="dash"),
            showlegend=False,
        ),
        row=2,
        col=1,
    )

# Bottom-right: QQ points plus the y = x reference line.
fig.add_trace(
    go.Scatter(x=theoretical, y=sample_q, mode="markers", name="QQ", marker=dict(color="#4E79A7")),
    row=2,
    col=2,
)
fig.add_trace(
    go.Scatter(
        x=[theoretical.min(), theoretical.max()],
        y=[theoretical.min(), theoretical.max()],
        mode="lines",
        line=dict(color="black", dash="dash"),
        showlegend=False,
    ),
    row=2,
    col=2,
)
fig.update_layout(height=750, title="BATS residual diagnostics")
fig.show()
residual mean: -0.003120780574746758
residual std: 0.7024190892644842
Jarque-Bera: SignificanceResult(statistic=0.20620374582574355, pvalue=0.9020350758599663)